Diffusion Mapping¶

In [1]:
import numpy as np
x = [-10,0,10]
b = .03

# Create a 1000 x 3 matrix
output_matrix = np.zeros((1000, 3))

# Generate random values for each value in x
for i, value in enumerate(x):
    output_matrix[0, i] = np.random.normal(x[i], np.sqrt(b), 1)

for i in range(1, 1000):
    y = output_matrix[i-1, :]
    output_matrix[i, 0] = np.random.normal(np.sqrt(1 - b) * y[0], np.sqrt(b), 1)
    output_matrix[i, 1] = np.random.normal(np.sqrt(1 - b) * y[1], np.sqrt(b), 1)
    output_matrix[i, 2] = np.random.normal(np.sqrt(1 - b) * y[2], np.sqrt(b), 1)

import matplotlib.pyplot as plt

# Create a time array from 0 to 999
time = range(1000)

# Plot each column of output_matrix
for i in range(3):
    plt.plot(time, output_matrix[:, i], label=f'Value {i+1}')

# Add labels and legend
plt.xlabel('Time')
plt.ylabel('Value')
plt.legend()

# Show the plot
plt.show()
    
No description has been provided for this image
In [2]:
import numpy as np
tps = [1, 10, 25, 50, 100, 200]

# Compute (1 - b)^tps for each value in tps
bt = np.power(1 - b, tps)

import matplotlib.pyplot as plt

import seaborn as sns

# Create a 2 x 3 grid of subplots
fig, axs = plt.subplots(2, 3, figsize=(12, 8))

# Iterate over each tps value
for i, tp in enumerate(tps):
    
    z = output_matrix[tp-1, :]
    
    # Compute the mean and variance
    mean = np.sqrt(bt[i]) * z
    variance = 1 - bt[i]
    
    # Generate normal random numbers
    samples1 = np.random.normal(mean[0], np.sqrt(variance), 10000)
    samples2 = np.random.normal(mean[1], np.sqrt(variance), 10000)
    samples3 = np.random.normal(mean[2], np.sqrt(variance), 10000)
    
    # Plot the density
    row = i // 3
    col = i % 3
    sns.kdeplot(samples1, ax=axs[row, col])
    sns.kdeplot(samples2, ax=axs[row, col])
    sns.kdeplot(samples3, ax=axs[row, col])
    axs[row, col].set_title(f'Step = {tp}')

# Adjust the spacing between subplots
plt.tight_layout()

# Show the plot
plt.show()
No description has been provided for this image

Simple Diffusion for Sampling from 9 Component Mixture¶

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

# -----------------------------------------------------------------------------
# 1) Create your 2D 9‑Gaussian mixture dataset
# -----------------------------------------------------------------------------
centers = torch.tensor([[x, y] for x in (-2, 0, 2) for y in (-2, 0, 2)],
                       dtype=torch.float32)
n_samples = 10000

# sample mixture indices uniformly
idx = torch.randint(0, 9, (n_samples,))
# draw x0 ~ N(center[idx], 0.1^2 I)
x0 = centers[idx] + torch.randn(n_samples, 2) * 0.1

dataset = TensorDataset(x0)
loader  = DataLoader(dataset, batch_size=256, shuffle=True)
In [2]:
import seaborn as sns

import matplotlib.pyplot as plt

# Extract x and y coordinates from x0
x = x0[:, 0].numpy()
y = x0[:, 1].numpy()

plt.figure(figsize=(6,6))
sns.kdeplot(x=x, y=y, fill=True, cmap="viridis")
plt.title("Density Plot for x0")
plt.xlabel("x coordinate")
plt.ylabel("y coordinate")
plt.show()
No description has been provided for this image
In [9]:
# -----------------------------------------------------------------------------
# 2) Diffusion schedule & forward q_sample
# -----------------------------------------------------------------------------
T       = 500
beta    = torch.linspace(1e-4, .05, T)
alpha   = 1 - beta
alphac  = torch.cumprod(alpha, dim=0)

def q_sample(x0, t, noise):
    """
    x_t = sqrt(alpha_bar_t)*x0 + sqrt(1−alpha_bar_t)*noise
    where
      alpha_bar_t = alphac[t]  shape [batch_size]
    """
    # grab alpha_bar for each sample and unsqueeze so it’s [B,1]
    a_bar = alphac[t].unsqueeze(-1)       # now shape (B,1)
    # and broadcast multiply against the 2‑D points
    return torch.sqrt(a_bar) * x0 + torch.sqrt(1.0 - a_bar) * noise
In [10]:
import matplotlib.pyplot as plt

# Define the t values for which we'll generate q_sample images
t_values = [0, 100, 200, 300, 400]

# Create a constant noise tensor with value 0.01 (same shape as x0)
noise_const = torch.full_like(x0, 0)

# Create a subplot with one row and 5 columns
fig, axes = plt.subplots(1, 5, figsize=(20, 4))

for ax, t_val in zip(axes, t_values):
    # Generate diffused samples at time t_val using q_sample
    xt = q_sample(x0, torch.tensor(t_val), noise_const)
    # Convert to numpy for plotting
    xt_np = xt.detach().numpy()
    # Create a scatter plot of the diffused points
    ax.scatter(xt_np[:, 0], xt_np[:, 1], s=1, alpha=0.6)
    ax.set_title(f"t = {t_val}")
    ax.set_xlim([-3, 3])
    ax.set_ylim([-3, 3])
    ax.set_aspect('equal')

plt.tight_layout()
plt.show()
No description has been provided for this image
In [11]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# -------------------------------------------------------------------
# 1) Sinusoidal positional embedding for timesteps
# -------------------------------------------------------------------
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        """
        dim: total dimension of the output embedding (must be even)
        """
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """
        t: tensor of shape [B] with integer timesteps in [0, T-1]
        returns: [B, dim] sinusoidal embedding
        """
        half = self.dim // 2
        # frequencies: exp(-log(10000)*(0..half-1)/(half-1))
        freqs = torch.exp(
            -math.log(10000) * torch.arange(half, device=t.device).float() / (half - 1)
        )  # [half]
        args = t.float().unsqueeze(1) * freqs.unsqueeze(0)  # [B, half]
        emb = torch.cat([args.sin(), args.cos()], dim=-1)   # [B, dim]
        return emb

# -------------------------------------------------------------------
# 2) Score network with explicit time embedding
# -------------------------------------------------------------------
class ScoreNet2D(nn.Module):
    def __init__(self, temb_dim=64, hidden_dim=128):
        """
        temb_dim: dimension of the sinusoidal time embedding
        hidden_dim: hidden layer width
        """
        super().__init__()
        # map t -> sinusoidal embedding -> project
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(temb_dim),
            nn.Linear(temb_dim, temb_dim),
            nn.SiLU(),
        )
        # main MLP: input is (x:2 dims) + (temb_dim)
        self.net = nn.Sequential(
            nn.Linear(2 + temb_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),      nn.ReLU(),
            nn.Linear(hidden_dim, 2)
        )

    def forward(self, x, t):
        """
        x: [B, 2]  2D points
        t: [B]     integer timesteps in {0,...,T-1}
        """
        # 1) normalize t to [0,1]
        t = t.float() / (T - 1)
        # 2) get time embedding
        temb = self.time_mlp(t)          # [B, temb_dim]
        # 3) concatenate and predict noise
        inp = torch.cat([x, temb], dim=1)  # [B, 2+temb_dim]
        return self.net(inp)               # [B, 2]
In [12]:
# -----------------------------------------------------------------------------
# 4) Reverse (ancestral) sampler
# -----------------------------------------------------------------------------
@torch.no_grad()
def sample(model, n_samples, device):
    model.eval()
    x = torch.randn(n_samples, 2, device=device)
    for i in reversed(range(T)):
        t = torch.full((n_samples,), i, device=device, dtype=torch.long)
        eps = model(x, t).clamp(-5, 5)
        b = beta[i]
        a = alpha[i]
        atil = alphac[i]

        # ancestral update
        mean = (1/math.sqrt(a)) * (x - (b / math.sqrt(1-atil)) * eps)
        if i > 0:
            noise = torch.randn_like(x)
            x = mean + math.sqrt(b) * noise
        else:
            x = mean
    return x.cpu()
In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
beta   = beta.to(device)
alpha  = alpha.to(device)
alphac = alphac.to(device)
model  = ScoreNet2D(temb_dim=64, hidden_dim=128).to(device)
opt    = torch.optim.Adam(model.parameters(), lr=1e-3)
In [19]:
epochs = 500
for epoch in range(epochs):
    running_loss = 0.0
    for (x0_batch,) in loader:
        x0_batch = x0_batch.to(device)
        b = x0_batch.size(0)
        t = torch.randint(0, T, (b,), device=device)
        noise = torch.randn_like(x0_batch)

        xt = q_sample(x0_batch, t, noise)
        eps_pred = model(xt, t)
        loss = F.mse_loss(eps_pred, noise)

        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        running_loss += loss.item() * b


    if (epoch+1) % 25 == 0 or epoch == 0:
        avg_loss = running_loss / n_samples
        print(f"Epoch {epoch+1:3d}/{epochs}, Loss: {avg_loss:.4f}")
        gen = sample(model, 10000, device)  # [10000,2]
        gen = gen.numpy()
        plt.figure(figsize=(6,6))
        plt.hist2d(gen[:,0], gen[:,1], bins=100)
        plt.title(f"Density of Generated Samples at Epoch {epoch+1}")
        plt.axis('equal')
        plt.show()
Epoch   1/500, Loss: 0.6537
No description has been provided for this image
Epoch  25/500, Loss: 0.2749
No description has been provided for this image
Epoch  50/500, Loss: 0.2534
No description has been provided for this image
Epoch  75/500, Loss: 0.2493
No description has been provided for this image
Epoch 100/500, Loss: 0.2259
No description has been provided for this image
Epoch 125/500, Loss: 0.2367
No description has been provided for this image
Epoch 150/500, Loss: 0.2162
No description has been provided for this image
Epoch 175/500, Loss: 0.2207
No description has been provided for this image
Epoch 200/500, Loss: 0.2114
No description has been provided for this image
Epoch 225/500, Loss: 0.2090
No description has been provided for this image
Epoch 250/500, Loss: 0.2205
No description has been provided for this image
Epoch 275/500, Loss: 0.2301
No description has been provided for this image
Epoch 300/500, Loss: 0.2164
No description has been provided for this image
Epoch 325/500, Loss: 0.2218
No description has been provided for this image
Epoch 350/500, Loss: 0.2287
No description has been provided for this image
Epoch 375/500, Loss: 0.2175
No description has been provided for this image
Epoch 400/500, Loss: 0.2117
No description has been provided for this image
Epoch 425/500, Loss: 0.2138
No description has been provided for this image
Epoch 450/500, Loss: 0.2094
No description has been provided for this image
Epoch 475/500, Loss: 0.2280
No description has been provided for this image
Epoch 500/500, Loss: 0.2273
No description has been provided for this image

Diffusion Model w/ Attention for CelebA¶

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torchvision.utils as vutils
from tqdm import tqdm

# -----------------------------------------------------------------------------
# 1) Noise schedule + sampling helper
# -----------------------------------------------------------------------------
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    return torch.linspace(beta_start, beta_end, timesteps)

class DiffusionSchedule:
    def __init__(self, timesteps=1000):
        self.timesteps = timesteps
        self.betas = linear_beta_schedule(timesteps).to(device)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod    = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_acumprod = torch.sqrt(1 - self.alphas_cumprod)

    def q_sample(self, x_start, t, noise):
        """
        Diffuse x_start to x_t by adding noise at step t:
           x_t = sqrt(ᾱ_t) * x_0 + sqrt(1 - ᾱ_t) * ε
        """
        return (
            self.sqrt_alphas_cumprod[t].view(-1,1,1,1) * x_start +
            self.sqrt_one_minus_acumprod[t].view(-1,1,1,1) * noise
        )
In [2]:
@torch.no_grad()
def p_sample_loop(model, diffusion, shape):
    model.eval()
    x = torch.randn(shape, device=device)

    for i in reversed(range(diffusion.timesteps)):
        # 1) prepare timestep tensor and predict ε
        t = torch.full((shape[0],), i, device=device, dtype=torch.long)
        eps_pred = model(x, t.float() / diffusion.timesteps)
        # clamp huge predictions
        eps_pred = eps_pred.clamp(-5.0, 5.0)

        # 2) grab scalars for this step
        beta_t      = diffusion.betas[i]
        alpha_t     = diffusion.alphas[i]
        alpha_bar_t = diffusion.alphas_cumprod[i]

        # 3) stable sqrt’s
        sqrt_alpha_t     = torch.sqrt(alpha_t).clamp(min=1e-5)
        sqrt_one_minus_ab = torch.sqrt(1 - alpha_bar_t).clamp(min=1e-5)

        # 4) ancestral mean: 
        #    μ = (1/√α_t) [ x_t − (β_t / √(1−ᾱ_t)) · ε_pred ]
        mean = (1.0 / sqrt_alpha_t) * (
            x - (beta_t / sqrt_one_minus_ab) * eps_pred
        )

        # 5) sample from p(x_{t−1}|x_t)
        if i > 0:
            noise = torch.randn_like(x)
            sigma_t = torch.sqrt(beta_t)
            x = mean + sigma_t * noise
        else:
            x = mean

    model.train()
    return x.clamp(-1, 1)
In [3]:
# -----------------------------------------------------------------------------
# 2) Time embedding, ResBlock, MHSA2d, Down/Up, UNetAttention
# -----------------------------------------------------------------------------
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
    def forward(self, t):
        half = self.dim // 2
        freqs = torch.exp(
            -math.log(10000) * torch.arange(half, device=t.device).float() / (half - 1)
        )
        args = t.float()[:, None] * freqs[None]
        return torch.cat([args.sin(), args.cos()], dim=-1)

class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, temb_dim=None, dropout=0.0):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.emb_proj = nn.Linear(temb_dim, out_ch) if temb_dim is not None else None
        self.res_conv = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, temb):
        h = self.conv1(F.silu(self.norm1(x)))
        if self.emb_proj is not None:
            h = h + self.emb_proj(temb)[:, :, None, None]
        h = self.conv2(self.dropout(F.silu(self.norm2(h))))
        return h + self.res_conv(x)
In [4]:
class MHSA2d(nn.Module):
    def __init__(self, channels, num_heads=4, dropout=0.0):
        super().__init__()
        self.norm = nn.GroupNorm(8, channels)
        self.attn = nn.MultiheadAttention(
            embed_dim=channels,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=False
        )
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        B,C,H,W = x.shape
        h = self.norm(x).view(B, C, H*W).permute(2,0,1)   # (S,B,C)
        out, _ = self.attn(h, h, h)
        out = out.permute(1,2,0).view(B,C,H,W)
        return x + self.proj(out)
In [5]:
class Downsample(nn.Module):
    def __init__(self, ch): super().__init__(); self.conv = nn.Conv2d(ch,ch,3,2,1)
    def forward(self,x): return self.conv(x)

class Upsample(nn.Module):
    def __init__(self,ch): super().__init__(); self.conv = nn.Conv2d(ch,ch,3,1,1)
    def forward(self,x):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x)
In [6]:
class UNetAttention(nn.Module):
    def __init__(
        self,
        in_ch=3,
        base_ch=128,
        chan_mults=(1,1,2,2,4,4),
        num_res_blocks=2,
        temb_dim=512,
        dropout=0.0
    ):
        super().__init__()
        self.num_res_blocks = num_res_blocks

        # 1) Time‐step embedding MLP
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(temb_dim//2),
            nn.Linear(temb_dim//2, temb_dim),
            nn.SiLU(),
            nn.Linear(temb_dim, temb_dim),
        )

        # 2) Initial conv
        ch = base_ch
        self.init_conv = nn.Conv2d(in_ch, ch, 3, padding=1)

        # 3) Down path: build blocks & record skip channels
        self.down_blocks  = nn.ModuleList()
        self.down_samples = nn.ModuleList()
        skip_channels = []
        for i, mult in enumerate(chan_mults):
            out_ch = base_ch * mult
            for _ in range(num_res_blocks):
                self.down_blocks.append( ResBlock(ch, out_ch, temb_dim, dropout) )
                skip_channels.append(out_ch)    # record this for the up path
                ch = out_ch
            if i < len(chan_mults) - 1:
                self.down_samples.append( Downsample(ch) )

        # 4) Bottleneck
        self.mid1     = ResBlock(ch, ch, temb_dim, dropout)
        self.mid_attn = MHSA2d(ch, num_heads=4, dropout=dropout)
        self.mid2     = ResBlock(ch, ch, temb_dim, dropout)

        # 5) Up path: consume skip_channels in reverse
        self.up_blocks  = nn.ModuleList()
        self.up_samples = nn.ModuleList()
        skip_idx = len(skip_channels) - 1
        for i, mult in reversed(list(enumerate(chan_mults))):
            out_ch = base_ch * mult
            # add an upsample layer if we're not at the very first (lowest) resolution
            if i < len(chan_mults) - 1:
                self.up_samples.append( Upsample(ch) )
            # for each residual block in this stage, cat with a skip from down
            for _ in range(num_res_blocks):
                skip_ch = skip_channels[skip_idx]
                skip_idx -= 1
                self.up_blocks.append( ResBlock(ch + skip_ch, out_ch, temb_dim, dropout) )
                ch = out_ch

        # 6) Final norm → activation → 1×1 conv
        self.final_norm = nn.GroupNorm(8, ch)
        self.final_conv = nn.Conv2d(ch, in_ch, 1)

    def forward(self, x, t):
        # time embed
        temb = self.time_mlp(t)

        # down
        h = self.init_conv(x)
        skips = []
        bi = 0
        for block in self.down_blocks:
            h = block(h, temb)
            skips.append(h)
            # after each group of num_res_blocks, do a downsample if available
            if (bi + 1) % self.num_res_blocks == 0 and (bi // self.num_res_blocks) < len(self.down_samples):
                h = self.down_samples[bi // self.num_res_blocks](h)
            bi += 1

        # bottleneck
        h = self.mid1(h, temb)
        h = self.mid_attn(h)
        h = self.mid2(h, temb)

        # 4) up path (fixed)
        bi = 0
        for block in self.up_blocks:
            # only upsample at the start of each new resolution,
            # and never at bi=0 (i.e. before the first two blocks).
            if bi % self.num_res_blocks == 0 and bi > 0:
                us_idx = bi // self.num_res_blocks - 1
                h = self.up_samples[us_idx](h)

            skip = skips.pop()
            h = block(torch.cat([h, skip], dim=1), temb)
            bi += 1

        # 5) final norm → SiLU → conv
        h = F.silu(self.final_norm(h))
        return self.final_conv(h)
In [7]:
# -----------------------------------------------------------------------------
# 3) Data, model, optimizer, schedule (CelebA subset @ 128×128)
# -----------------------------------------------------------------------------
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset

class RawImageDataset(Dataset):
    def __init__(self, root, transform=None, exts=(".png",".jpg",".jpeg")):
        self.paths = [p for p in Path(root).iterdir() if p.suffix.lower() in exts]
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
    transforms.Resize((64,64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,)*3, (0.5,)*3),
])

train_ds = RawImageDataset("subset_images", transform=transform)
train_loader = DataLoader(
    train_ds,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

# model & diffusion
model     = UNetAttention().to(device)
diffusion = DiffusionSchedule(timesteps=500)
optimizer = optim.Adam(model.parameters(), lr=2e-4)
In [8]:
from torchinfo import summary

# Create dummy inputs: a dummy image tensor and a dummy time tensor.
# The model expects x with shape (B,3,64,64) and t with shape (B,)
dummy_x = torch.randn(1, 3, 64, 64, device=device)
dummy_t = torch.tensor([0.0], device=device)  # a dummy timestep (scaled value)

# Print the model summary
summary(model, input_data=(dummy_x, dummy_t))
Out[8]:
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
UNetAttention                            [1, 3, 64, 64]            --
├─Sequential: 1-1                        [1, 512]                  --
│    └─SinusoidalPosEmb: 2-1             [1, 256]                  --
│    └─Linear: 2-2                       [1, 512]                  131,584
│    └─SiLU: 2-3                         [1, 512]                  --
│    └─Linear: 2-4                       [1, 512]                  262,656
├─Conv2d: 1-2                            [1, 128, 64, 64]          3,584
├─ModuleList: 1-13                       --                        (recursive)
│    └─ResBlock: 2-5                     [1, 128, 64, 64]          --
│    │    └─GroupNorm: 3-1               [1, 128, 64, 64]          256
│    │    └─Conv2d: 3-2                  [1, 128, 64, 64]          147,584
│    │    └─Linear: 3-3                  [1, 128]                  65,664
│    │    └─GroupNorm: 3-4               [1, 128, 64, 64]          256
│    │    └─Dropout: 3-5                 [1, 128, 64, 64]          --
│    │    └─Conv2d: 3-6                  [1, 128, 64, 64]          147,584
│    │    └─Identity: 3-7                [1, 128, 64, 64]          --
│    └─ResBlock: 2-6                     [1, 128, 64, 64]          --
│    │    └─GroupNorm: 3-8               [1, 128, 64, 64]          256
│    │    └─Conv2d: 3-9                  [1, 128, 64, 64]          147,584
│    │    └─Linear: 3-10                 [1, 128]                  65,664
│    │    └─GroupNorm: 3-11              [1, 128, 64, 64]          256
│    │    └─Dropout: 3-12                [1, 128, 64, 64]          --
│    │    └─Conv2d: 3-13                 [1, 128, 64, 64]          147,584
│    │    └─Identity: 3-14               [1, 128, 64, 64]          --
├─ModuleList: 1-12                       --                        (recursive)
│    └─Downsample: 2-7                   [1, 128, 32, 32]          --
│    │    └─Conv2d: 3-15                 [1, 128, 32, 32]          147,584
├─ModuleList: 1-13                       --                        (recursive)
│    └─ResBlock: 2-8                     [1, 128, 32, 32]          --
│    │    └─GroupNorm: 3-16              [1, 128, 32, 32]          256
│    │    └─Conv2d: 3-17                 [1, 128, 32, 32]          147,584
│    │    └─Linear: 3-18                 [1, 128]                  65,664
│    │    └─GroupNorm: 3-19              [1, 128, 32, 32]          256
│    │    └─Dropout: 3-20                [1, 128, 32, 32]          --
│    │    └─Conv2d: 3-21                 [1, 128, 32, 32]          147,584
│    │    └─Identity: 3-22               [1, 128, 32, 32]          --
│    └─ResBlock: 2-9                     [1, 128, 32, 32]          --
│    │    └─GroupNorm: 3-23              [1, 128, 32, 32]          256
│    │    └─Conv2d: 3-24                 [1, 128, 32, 32]          147,584
│    │    └─Linear: 3-25                 [1, 128]                  65,664
│    │    └─GroupNorm: 3-26              [1, 128, 32, 32]          256
│    │    └─Dropout: 3-27                [1, 128, 32, 32]          --
│    │    └─Conv2d: 3-28                 [1, 128, 32, 32]          147,584
│    │    └─Identity: 3-29               [1, 128, 32, 32]          --
├─ModuleList: 1-12                       --                        (recursive)
│    └─Downsample: 2-10                  [1, 128, 16, 16]          --
│    │    └─Conv2d: 3-30                 [1, 128, 16, 16]          147,584
├─ModuleList: 1-13                       --                        (recursive)
│    └─ResBlock: 2-11                    [1, 256, 16, 16]          --
│    │    └─GroupNorm: 3-31              [1, 128, 16, 16]          256
│    │    └─Conv2d: 3-32                 [1, 256, 16, 16]          295,168
│    │    └─Linear: 3-33                 [1, 256]                  131,328
│    │    └─GroupNorm: 3-34              [1, 256, 16, 16]          512
│    │    └─Dropout: 3-35                [1, 256, 16, 16]          --
│    │    └─Conv2d: 3-36                 [1, 256, 16, 16]          590,080
│    │    └─Conv2d: 3-37                 [1, 256, 16, 16]          33,024
│    └─ResBlock: 2-12                    [1, 256, 16, 16]          --
│    │    └─GroupNorm: 3-38              [1, 256, 16, 16]          512
│    │    └─Conv2d: 3-39                 [1, 256, 16, 16]          590,080
│    │    └─Linear: 3-40                 [1, 256]                  131,328
│    │    └─GroupNorm: 3-41              [1, 256, 16, 16]          512
│    │    └─Dropout: 3-42                [1, 256, 16, 16]          --
│    │    └─Conv2d: 3-43                 [1, 256, 16, 16]          590,080
│    │    └─Identity: 3-44               [1, 256, 16, 16]          --
├─ModuleList: 1-12                       --                        (recursive)
│    └─Downsample: 2-13                  [1, 256, 8, 8]            --
│    │    └─Conv2d: 3-45                 [1, 256, 8, 8]            590,080
├─ModuleList: 1-13                       --                        (recursive)
│    └─ResBlock: 2-14                    [1, 256, 8, 8]            --
│    │    └─GroupNorm: 3-46              [1, 256, 8, 8]            512
│    │    └─Conv2d: 3-47                 [1, 256, 8, 8]            590,080
│    │    └─Linear: 3-48                 [1, 256]                  131,328
│    │    └─GroupNorm: 3-49              [1, 256, 8, 8]            512
│    │    └─Dropout: 3-50                [1, 256, 8, 8]            --
│    │    └─Conv2d: 3-51                 [1, 256, 8, 8]            590,080
│    │    └─Identity: 3-52               [1, 256, 8, 8]            --
│    └─ResBlock: 2-15                    [1, 256, 8, 8]            --
│    │    └─GroupNorm: 3-53              [1, 256, 8, 8]            512
│    │    └─Conv2d: 3-54                 [1, 256, 8, 8]            590,080
│    │    └─Linear: 3-55                 [1, 256]                  131,328
│    │    └─GroupNorm: 3-56              [1, 256, 8, 8]            512
│    │    └─Dropout: 3-57                [1, 256, 8, 8]            --
│    │    └─Conv2d: 3-58                 [1, 256, 8, 8]            590,080
│    │    └─Identity: 3-59               [1, 256, 8, 8]            --
├─ModuleList: 1-12                       --                        (recursive)
│    └─Downsample: 2-16                  [1, 256, 4, 4]            --
│    │    └─Conv2d: 3-60                 [1, 256, 4, 4]            590,080
├─ModuleList: 1-13                       --                        (recursive)
│    └─ResBlock: 2-17                    [1, 512, 4, 4]            --
│    │    └─GroupNorm: 3-61              [1, 256, 4, 4]            512
│    │    └─Conv2d: 3-62                 [1, 512, 4, 4]            1,180,160
│    │    └─Linear: 3-63                 [1, 512]                  262,656
│    │    └─GroupNorm: 3-64              [1, 512, 4, 4]            1,024
│    │    └─Dropout: 3-65                [1, 512, 4, 4]            --
│    │    └─Conv2d: 3-66                 [1, 512, 4, 4]            2,359,808
│    │    └─Conv2d: 3-67                 [1, 512, 4, 4]            131,584
│    └─ResBlock: 2-18                    [1, 512, 4, 4]            --
│    │    └─GroupNorm: 3-68              [1, 512, 4, 4]            1,024
│    │    └─Conv2d: 3-69                 [1, 512, 4, 4]            2,359,808
│    │    └─Linear: 3-70                 [1, 512]                  262,656
│    │    └─GroupNorm: 3-71              [1, 512, 4, 4]            1,024
│    │    └─Dropout: 3-72                [1, 512, 4, 4]            --
│    │    └─Conv2d: 3-73                 [1, 512, 4, 4]            2,359,808
│    │    └─Identity: 3-74               [1, 512, 4, 4]            --
├─ModuleList: 1-12                       --                        (recursive)
│    └─Downsample: 2-19                  [1, 512, 2, 2]            --
│    │    └─Conv2d: 3-75                 [1, 512, 2, 2]            2,359,808
├─ModuleList: 1-13                       --                        (recursive)
│    └─ResBlock: 2-20                    [1, 512, 2, 2]            --
│    │    └─GroupNorm: 3-76              [1, 512, 2, 2]            1,024
│    │    └─Conv2d: 3-77                 [1, 512, 2, 2]            2,359,808
│    │    └─Linear: 3-78                 [1, 512]                  262,656
│    │    └─GroupNorm: 3-79              [1, 512, 2, 2]            1,024
│    │    └─Dropout: 3-80                [1, 512, 2, 2]            --
│    │    └─Conv2d: 3-81                 [1, 512, 2, 2]            2,359,808
│    │    └─Identity: 3-82               [1, 512, 2, 2]            --
│    └─ResBlock: 2-21                    [1, 512, 2, 2]            --
│    │    └─GroupNorm: 3-83              [1, 512, 2, 2]            1,024
│    │    └─Conv2d: 3-84                 [1, 512, 2, 2]            2,359,808
│    │    └─Linear: 3-85                 [1, 512]                  262,656
│    │    └─GroupNorm: 3-86              [1, 512, 2, 2]            1,024
│    │    └─Dropout: 3-87                [1, 512, 2, 2]            --
│    │    └─Conv2d: 3-88                 [1, 512, 2, 2]            2,359,808
│    │    └─Identity: 3-89               [1, 512, 2, 2]            --
├─ResBlock: 1-14                         [1, 512, 2, 2]            --
│    └─GroupNorm: 2-22                   [1, 512, 2, 2]            1,024
│    └─Conv2d: 2-23                      [1, 512, 2, 2]            2,359,808
│    └─Linear: 2-24                      [1, 512]                  262,656
│    └─GroupNorm: 2-25                   [1, 512, 2, 2]            1,024
│    └─Dropout: 2-26                     [1, 512, 2, 2]            --
│    └─Conv2d: 2-27                      [1, 512, 2, 2]            2,359,808
│    └─Identity: 2-28                    [1, 512, 2, 2]            --
├─MHSA2d: 1-15                           [1, 512, 2, 2]            --
│    └─GroupNorm: 2-29                   [1, 512, 2, 2]            1,024
│    └─MultiheadAttention: 2-30          [4, 1, 512]               1,050,624
│    └─Conv2d: 2-31                      [1, 512, 2, 2]            262,656
├─ResBlock: 1-16                         [1, 512, 2, 2]            --
│    └─GroupNorm: 2-32                   [1, 512, 2, 2]            1,024
│    └─Conv2d: 2-33                      [1, 512, 2, 2]            2,359,808
│    └─Linear: 2-34                      [1, 512]                  262,656
│    └─GroupNorm: 2-35                   [1, 512, 2, 2]            1,024
│    └─Dropout: 2-36                     [1, 512, 2, 2]            --
│    └─Conv2d: 2-37                      [1, 512, 2, 2]            2,359,808
│    └─Identity: 2-38                    [1, 512, 2, 2]            --
├─ModuleList: 1-27                       --                        (recursive)
│    └─ResBlock: 2-39                    [1, 512, 2, 2]            --
│    │    └─GroupNorm: 3-90              [1, 1024, 2, 2]           2,048
│    │    └─Conv2d: 3-91                 [1, 512, 2, 2]            4,719,104
│    │    └─Linear: 3-92                 [1, 512]                  262,656
│    │    └─GroupNorm: 3-93              [1, 512, 2, 2]            1,024
│    │    └─Dropout: 3-94                [1, 512, 2, 2]            --
│    │    └─Conv2d: 3-95                 [1, 512, 2, 2]            2,359,808
│    │    └─Conv2d: 3-96                 [1, 512, 2, 2]            524,800
│    └─ResBlock: 2-40                    [1, 512, 2, 2]            --
│    │    └─GroupNorm: 3-97              [1, 1024, 2, 2]           2,048
│    │    └─Conv2d: 3-98                 [1, 512, 2, 2]            4,719,104
│    │    └─Linear: 3-99                 [1, 512]                  262,656
│    │    └─GroupNorm: 3-100             [1, 512, 2, 2]            1,024
│    │    └─Dropout: 3-101               [1, 512, 2, 2]            --
│    │    └─Conv2d: 3-102                [1, 512, 2, 2]            2,359,808
│    │    └─Conv2d: 3-103                [1, 512, 2, 2]            524,800
├─ModuleList: 1-26                       --                        (recursive)
│    └─Upsample: 2-41                    [1, 512, 4, 4]            --
│    │    └─Conv2d: 3-104                [1, 512, 4, 4]            2,359,808
├─ModuleList: 1-27                       --                        (recursive)
│    └─ResBlock: 2-42                    [1, 512, 4, 4]            --
│    │    └─GroupNorm: 3-105             [1, 1024, 4, 4]           2,048
│    │    └─Conv2d: 3-106                [1, 512, 4, 4]            4,719,104
│    │    └─Linear: 3-107                [1, 512]                  262,656
│    │    └─GroupNorm: 3-108             [1, 512, 4, 4]            1,024
│    │    └─Dropout: 3-109               [1, 512, 4, 4]            --
│    │    └─Conv2d: 3-110                [1, 512, 4, 4]            2,359,808
│    │    └─Conv2d: 3-111                [1, 512, 4, 4]            524,800
│    └─ResBlock: 2-43                    [1, 512, 4, 4]            --
│    │    └─GroupNorm: 3-112             [1, 1024, 4, 4]           2,048
│    │    └─Conv2d: 3-113                [1, 512, 4, 4]            4,719,104
│    │    └─Linear: 3-114                [1, 512]                  262,656
│    │    └─GroupNorm: 3-115             [1, 512, 4, 4]            1,024
│    │    └─Dropout: 3-116               [1, 512, 4, 4]            --
│    │    └─Conv2d: 3-117                [1, 512, 4, 4]            2,359,808
│    │    └─Conv2d: 3-118                [1, 512, 4, 4]            524,800
├─ModuleList: 1-26                       --                        (recursive)
│    └─Upsample: 2-44                    [1, 512, 8, 8]            --
│    │    └─Conv2d: 3-119                [1, 512, 8, 8]            2,359,808
├─ModuleList: 1-27                       --                        (recursive)
│    └─ResBlock: 2-45                    [1, 256, 8, 8]            --
│    │    └─GroupNorm: 3-120             [1, 768, 8, 8]            1,536
│    │    └─Conv2d: 3-121                [1, 256, 8, 8]            1,769,728
│    │    └─Linear: 3-122                [1, 256]                  131,328
│    │    └─GroupNorm: 3-123             [1, 256, 8, 8]            512
│    │    └─Dropout: 3-124               [1, 256, 8, 8]            --
│    │    └─Conv2d: 3-125                [1, 256, 8, 8]            590,080
│    │    └─Conv2d: 3-126                [1, 256, 8, 8]            196,864
│    └─ResBlock: 2-46                    [1, 256, 8, 8]            --
│    │    └─GroupNorm: 3-127             [1, 512, 8, 8]            1,024
│    │    └─Conv2d: 3-128                [1, 256, 8, 8]            1,179,904
│    │    └─Linear: 3-129                [1, 256]                  131,328
│    │    └─GroupNorm: 3-130             [1, 256, 8, 8]            512
│    │    └─Dropout: 3-131               [1, 256, 8, 8]            --
│    │    └─Conv2d: 3-132                [1, 256, 8, 8]            590,080
│    │    └─Conv2d: 3-133                [1, 256, 8, 8]            131,328
├─ModuleList: 1-26                       --                        (recursive)
│    └─Upsample: 2-47                    [1, 256, 16, 16]          --
│    │    └─Conv2d: 3-134                [1, 256, 16, 16]          590,080
├─ModuleList: 1-27                       --                        (recursive)
│    └─ResBlock: 2-48                    [1, 256, 16, 16]          --
│    │    └─GroupNorm: 3-135             [1, 512, 16, 16]          1,024
│    │    └─Conv2d: 3-136                [1, 256, 16, 16]          1,179,904
│    │    └─Linear: 3-137                [1, 256]                  131,328
│    │    └─GroupNorm: 3-138             [1, 256, 16, 16]          512
│    │    └─Dropout: 3-139               [1, 256, 16, 16]          --
│    │    └─Conv2d: 3-140                [1, 256, 16, 16]          590,080
│    │    └─Conv2d: 3-141                [1, 256, 16, 16]          131,328
│    └─ResBlock: 2-49                    [1, 256, 16, 16]          --
│    │    └─GroupNorm: 3-142             [1, 512, 16, 16]          1,024
│    │    └─Conv2d: 3-143                [1, 256, 16, 16]          1,179,904
│    │    └─Linear: 3-144                [1, 256]                  131,328
│    │    └─GroupNorm: 3-145             [1, 256, 16, 16]          512
│    │    └─Dropout: 3-146               [1, 256, 16, 16]          --
│    │    └─Conv2d: 3-147                [1, 256, 16, 16]          590,080
│    │    └─Conv2d: 3-148                [1, 256, 16, 16]          131,328
├─ModuleList: 1-26                       --                        (recursive)
│    └─Upsample: 2-50                    [1, 256, 32, 32]          --
│    │    └─Conv2d: 3-149                [1, 256, 32, 32]          590,080
├─ModuleList: 1-27                       --                        (recursive)
│    └─ResBlock: 2-51                    [1, 128, 32, 32]          --
│    │    └─GroupNorm: 3-150             [1, 384, 32, 32]          768
│    │    └─Conv2d: 3-151                [1, 128, 32, 32]          442,496
│    │    └─Linear: 3-152                [1, 128]                  65,664
│    │    └─GroupNorm: 3-153             [1, 128, 32, 32]          256
│    │    └─Dropout: 3-154               [1, 128, 32, 32]          --
│    │    └─Conv2d: 3-155                [1, 128, 32, 32]          147,584
│    │    └─Conv2d: 3-156                [1, 128, 32, 32]          49,280
│    └─ResBlock: 2-52                    [1, 128, 32, 32]          --
│    │    └─GroupNorm: 3-157             [1, 256, 32, 32]          512
│    │    └─Conv2d: 3-158                [1, 128, 32, 32]          295,040
│    │    └─Linear: 3-159                [1, 128]                  65,664
│    │    └─GroupNorm: 3-160             [1, 128, 32, 32]          256
│    │    └─Dropout: 3-161               [1, 128, 32, 32]          --
│    │    └─Conv2d: 3-162                [1, 128, 32, 32]          147,584
│    │    └─Conv2d: 3-163                [1, 128, 32, 32]          32,896
├─ModuleList: 1-26                       --                        (recursive)
│    └─Upsample: 2-53                    [1, 128, 64, 64]          --
│    │    └─Conv2d: 3-164                [1, 128, 64, 64]          147,584
├─ModuleList: 1-27                       --                        (recursive)
│    └─ResBlock: 2-54                    [1, 128, 64, 64]          --
│    │    └─GroupNorm: 3-165             [1, 256, 64, 64]          512
│    │    └─Conv2d: 3-166                [1, 128, 64, 64]          295,040
│    │    └─Linear: 3-167                [1, 128]                  65,664
│    │    └─GroupNorm: 3-168             [1, 128, 64, 64]          256
│    │    └─Dropout: 3-169               [1, 128, 64, 64]          --
│    │    └─Conv2d: 3-170                [1, 128, 64, 64]          147,584
│    │    └─Conv2d: 3-171                [1, 128, 64, 64]          32,896
│    └─ResBlock: 2-55                    [1, 128, 64, 64]          --
│    │    └─GroupNorm: 3-172             [1, 256, 64, 64]          512
│    │    └─Conv2d: 3-173                [1, 128, 64, 64]          295,040
│    │    └─Linear: 3-174                [1, 128]                  65,664
│    │    └─GroupNorm: 3-175             [1, 128, 64, 64]          256
│    │    └─Dropout: 3-176               [1, 128, 64, 64]          --
│    │    └─Conv2d: 3-177                [1, 128, 64, 64]          147,584
│    │    └─Conv2d: 3-178                [1, 128, 64, 64]          32,896
├─GroupNorm: 1-28                        [1, 128, 64, 64]          256
├─Conv2d: 1-29                           [1, 3, 64, 64]            387
==========================================================================================
Total params: 89,488,131
Trainable params: 89,488,131
Non-trainable params: 0
Total mult-adds (G): 12.34
==========================================================================================
Input size (MB): 0.05
Forward/backward pass size (MB): 138.44
Params size (MB): 353.75
Estimated Total Size (MB): 492.24
==========================================================================================
In [8]:
# -----------------------------------------------------------------------------
# 4) Training loop
# -----------------------------------------------------------------------------
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

epochs = 50
for epoch in range(epochs):
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    for x in pbar:
        x = x.to(device)
        bsz = x.size(0)

        # sample t ∈ [0,T) and noise
        t = torch.randint(0, diffusion.timesteps, (bsz,), device=device)
        noise = torch.randn_like(x)

        # diffuse
        x_t = diffusion.q_sample(x, t, noise)

        # predict noise with scaled t
        noise_pred = model(x_t, t.float() / diffusion.timesteps)

        # compute MSE loss
        loss = F.mse_loss(noise_pred, noise)
        optimizer.zero_grad()
        loss.backward()

        # <-- clamp gradients to prevent explosion -->
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        pbar.set_postfix(loss=loss.item())

    # display samples every 10 epochs
    if epoch % 1 == 0:
        samples = p_sample_loop(model, diffusion, (25,3,64,64))
        grid = make_grid((samples + 1)*0.5, nrow=5).clamp(0,1)
        np_grid = grid.permute(1,2,0).cpu().numpy()
        plt.figure(figsize=(5,5))
        plt.imshow(np_grid)
        plt.axis('off')
        plt.show()
Epoch 0: 100%|██████████| 469/469 [01:10<00:00,  6.67it/s, loss=0.0408]
No description has been provided for this image
Epoch 1: 100%|██████████| 469/469 [01:10<00:00,  6.63it/s, loss=0.0827]
No description has been provided for this image
Epoch 2: 100%|██████████| 469/469 [01:07<00:00,  6.94it/s, loss=0.0295]
No description has been provided for this image
Epoch 3: 100%|██████████| 469/469 [01:05<00:00,  7.21it/s, loss=0.0306]
No description has been provided for this image
Epoch 4: 100%|██████████| 469/469 [01:05<00:00,  7.18it/s, loss=0.0329]
No description has been provided for this image
Epoch 5: 100%|██████████| 469/469 [01:10<00:00,  6.65it/s, loss=0.0337]
No description has been provided for this image
Epoch 6: 100%|██████████| 469/469 [01:09<00:00,  6.72it/s, loss=0.0286]
No description has been provided for this image
Epoch 7: 100%|██████████| 469/469 [01:10<00:00,  6.65it/s, loss=0.0423]
No description has been provided for this image
Epoch 8: 100%|██████████| 469/469 [01:05<00:00,  7.11it/s, loss=0.0208]
No description has been provided for this image
Epoch 9: 100%|██████████| 469/469 [01:04<00:00,  7.25it/s, loss=0.02]  
No description has been provided for this image
Epoch 10: 100%|██████████| 469/469 [01:04<00:00,  7.27it/s, loss=0.0108]
No description has been provided for this image
Epoch 11: 100%|██████████| 469/469 [01:04<00:00,  7.24it/s, loss=0.0391]
No description has been provided for this image
Epoch 12: 100%|██████████| 469/469 [01:04<00:00,  7.24it/s, loss=0.0166]
No description has been provided for this image
Epoch 13: 100%|██████████| 469/469 [01:08<00:00,  6.87it/s, loss=0.0182]
No description has been provided for this image
Epoch 14: 100%|██████████| 469/469 [01:10<00:00,  6.69it/s, loss=0.0231]
No description has been provided for this image
Epoch 15: 100%|██████████| 469/469 [01:09<00:00,  6.71it/s, loss=0.032] 
No description has been provided for this image
Epoch 16: 100%|██████████| 469/469 [01:10<00:00,  6.68it/s, loss=0.0333]
No description has been provided for this image
Epoch 17: 100%|██████████| 469/469 [01:09<00:00,  6.72it/s, loss=0.0173]
No description has been provided for this image
Epoch 18: 100%|██████████| 469/469 [01:10<00:00,  6.65it/s, loss=0.0289]
No description has been provided for this image
Epoch 19: 100%|██████████| 469/469 [01:10<00:00,  6.63it/s, loss=0.0182]
No description has been provided for this image
Epoch 20: 100%|██████████| 469/469 [01:09<00:00,  6.71it/s, loss=0.0343]
No description has been provided for this image
Epoch 21: 100%|██████████| 469/469 [01:10<00:00,  6.67it/s, loss=0.0227] 
No description has been provided for this image
Epoch 22:  19%|█▊        | 87/469 [00:13<00:58,  6.58it/s, loss=0.0324]
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[8], line 27
     25 loss = F.mse_loss(noise_pred, noise)
     26 optimizer.zero_grad()
---> 27 loss.backward()
     29 # <-- clamp gradients to prevent explosion -->
     30 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

File ~/.local/lib/python3.10/site-packages/torch/_tensor.py:581, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    571 if has_torch_function_unary(self):
    572     return handle_torch_function(
    573         Tensor.backward,
    574         (self,),
   (...)
    579         inputs=inputs,
    580     )
--> 581 torch.autograd.backward(
    582     self, gradient, retain_graph, create_graph, inputs=inputs
    583 )

File ~/.local/lib/python3.10/site-packages/torch/autograd/__init__.py:347, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    342     retain_graph = create_graph
    344 # The reason we repeat the same comment below is that
    345 # some Python versions print out the first line of a multi-line function
    346 # calls in the traceback and some print out the last line
--> 347 _engine_run_backward(
    348     tensors,
    349     grad_tensors_,
    350     retain_graph,
    351     create_graph,
    352     inputs,
    353     allow_unreachable=True,
    354     accumulate_grad=True,
    355 )

File ~/.local/lib/python3.10/site-packages/torch/autograd/graph.py:825, in _engine_run_backward(t_outputs, *args, **kwargs)
    823     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    824 try:
--> 825     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    826         t_outputs, *args, **kwargs
    827     )  # Calls into the C++ engine to run the backward pass
    828 finally:
    829     if attach_logging_hooks:

KeyboardInterrupt: 
In [11]:
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

# 1) Draw 100 samples from pure noise → reverse diffusion
#    shape = (N, C, H, W) = (100, 3, 128, 128)
samples = p_sample_loop(model, diffusion, (225, 3, 64, 64))

# 2) Un‑normalize from [–1,1] → [0,1]
samples = (samples + 1) * 0.5
samples = samples.clamp(0, 1)

# 3) Make a grid: 10 images per row → 10×10
grid = make_grid(samples, nrow=15)

# 4) Plot
plt.figure(figsize=(25,25))
# permute to H×W×C for plt
plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
plt.axis("off")
plt.title("Generated Faces")
plt.show()
No description has been provided for this image

Stable Diffusion¶

In [1]:
# 1) Install the libraries (run once in your environment)
! pip install diffusers transformers accelerate safetensors --quiet
[notice] A new release of pip is available: 24.0 -> 25.0.1
[notice] To update, run: pip install --upgrade pip
In [1]:
from huggingface_hub import login

# Read the token from the file
with open("/home/kmcalist/QTM447/HFToken.txt", "r") as token_file:
    hf_token = token_file.read().strip()

# Authenticate using the token
login(hf_token)
In [2]:
import torch
from diffusers import StableDiffusion3Pipeline
import matplotlib.pyplot as plt

model_id = "stabilityai/stable-diffusion-3.5-medium"
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # MMDiT performs best in bf16
    safety_checker=None          # if you want to skip the NSFW filter
)
pipe = pipe.to("cuda")
2025-04-17 12:38:41.666594: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-17 12:38:41.666619: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-17 12:38:41.667606: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-17 12:38:41.672577: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-04-17 12:38:42.496755: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Keyword arguments {'safety_checker': None} are not expected by StableDiffusion3Pipeline and will be ignored.
Loading pipeline components...:   0%|          | 0/9 [00:00<?, ?it/s]
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
In [4]:
# 3) Generate some images
prompts = [
    "Just a regular old celebrity wearing a hat in a realistic style.",
    "A cyberpunk dog samurai standing in neon rain",
    "15 students sitting in a classroom",
    "Dog pope"
]

# before you start, tell the pipeline to slice attention to save a bit more
pipe.enable_attention_slicing()

for prompt in prompts:
    # run just one prompt
    out = pipe(prompt, num_inference_steps=30, guidance_scale=7.5)
    img = out.images[0]

    # display immediately
    plt.figure(figsize=(4,4))
    plt.imshow(img)
    plt.axis("off")
    plt.show()

    # clean up
    del out, img
    torch.cuda.empty_cache()
  0%|          | 0/30 [00:00<?, ?it/s]
No description has been provided for this image
  0%|          | 0/30 [00:00<?, ?it/s]
No description has been provided for this image
  0%|          | 0/30 [00:00<?, ?it/s]
No description has been provided for this image
  0%|          | 0/30 [00:00<?, ?it/s]
No description has been provided for this image